Identifying model bias¶
In this second part we will use explanation methods to identify a faulty classifier that was trained on biased data. Specifically, each image contains an artifact whose color is related to the class of the image. A model trained with such images will likely learn to disregard the image content entirely and only focus on the artifact to make a prediction. You will use one of the explanation methods implemented in the first part to spot the issue.
Although in this example the bias was introduced artificially, such telling artifacts are not uncommon in real-world datasets. For example, in a dataset of X-ray scans one might find patient identifiers along the edges or marks left by doctors, which a model can exploit as a shortcut instead of learning the actual image content.
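To make the setup concrete, here is a minimal NumPy sketch of how a class-colored patch can leak label information into an image. The helper name and palette are hypothetical illustrations; the notebook's actual artifact is drawn by `add_bias_pixel` further down.

```python
import numpy as np

def add_class_colored_patch(image, label, palette, size=8):
    """Paint a small top-left patch whose color encodes the label.

    Hypothetical illustration only; the real artifact in this notebook
    is added by `add_bias_pixel` below.
    """
    out = image.copy()
    out[:size, :size, :] = palette[label]  # the patch color is a perfect proxy for the class
    return out

palette = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # red = class 0, blue = class 1
img = np.zeros((32, 32, 3))
biased = add_class_colored_patch(img, 1, palette)
# A classifier can now reach perfect training accuracy by reading the patch alone.
```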
Setup¶
!pip install "jax[cuda]" -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'
!pip install \
flax optax \
'git+https://github.com/n2cholas/jax-resnet.git' \
tensorflow-datasets \
better_exceptions
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
import tensorflow as tf
tf.get_logger().setLevel("WARNING")
tf.config.experimental.set_visible_devices([], "GPU")
import json
from functools import partial
from pathlib import Path
import flax
import flax.core
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax_resnet
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
from IPython.display import Markdown, display
Utils¶
CLASS_NAMES = (
"tench",
"English springer",
"cassette player",
"chain saw",
"church",
"French horn",
"garbage truck",
"gas pump",
"golf ball",
"parachute",
)
BIAS_COLORS = [
[0.1215, 0.4666, 0.7058],
[0.5490, 0.3372, 0.2941],
[1.0000, 0.4980, 0.0549],
[0.1725, 0.6274, 0.1725],
[0.8392, 0.1529, 0.1568],
[0.5803, 0.4039, 0.7411],
[0.8901, 0.4666, 0.7607],
[0.7372, 0.7411, 0.1333],
[0.0901, 0.7450, 0.8117],
[0.4980, 0.4980, 0.4980],
]
RED = np.array([1.0, 0, 0])
BLUE = np.array([0, 0, 1.0])
def create_dataset(data_dir: str, batch_size: int):
ds_builder = tfds.builder("imagenette/320px-v2", data_dir=data_dir)
ds_builder.download_and_prepare()
ds_val = ds_builder.as_dataset("validation", as_supervised=True)
ds_val = ds_val.map(resize)
ds_val = ds_val.map(add_bias_pixel)
ds_val = ds_val.batch(batch_size)
ds_val = tfds.as_numpy(ds_val)
return ds_val
def resize(image, label):
image = tf.image.resize_with_pad(image, 224, 224)
return image / 255.0, label
def add_bias_pixel(image, label):
hw_ = tf.reduce_sum(image, axis=[0, 1])
hw_ = tf.cast(hw_, tf.int32) % 30 + 140
h = hw_[0]
w = hw_[1]
color = tf.constant(BIAS_COLORS)[label]
mask = tf.meshgrid(tf.range(224), tf.range(224), indexing="ij")
mask = (
(mask[0] % 12 != tf.cast(label, tf.int32) + 1)
& (mask[0] > h)
& (mask[0] < h + 12)
& (mask[1] % 5 < 2)
& (mask[1] > w)
& (mask[1] < w + 30)
)
image = tf.where(mask[:, :, None], color, image)
return image, label
def load_checkpoint(path):
@jax.jit
def logits_fn(variables, img):
# img: [H, W, C], float32 in range [0, 1]
assert img.ndim == 3
img = normalize_for_resnet(img)
logits = model.apply(variables, img[None, ...], mutable=False)[0]
return logits.max(), logits
path = Path(path)
args = json.loads(Path.read_text(path / "args.json"))
variables_path = path / "variables.npy"
model = getattr(jax_resnet.resnet, f"ResNet{args['resnet_size']}")(n_classes=10)
variables = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 224, 224, 3)))
variables = flax.serialization.from_bytes(variables, variables_path.read_bytes())
return logits_fn, variables
def normalize_for_resnet(images):
# images: [..., H, W, 3], float32, range [0, 1]
mean = jnp.array([0.485, 0.456, 0.406])
std = jnp.array([0.229, 0.224, 0.225])
return (images - mean) / std
def imagenet_to_imagenette_logits(logits):
"""Select the 10 imagenette classes from the 1000 imagenet classes."""
return logits[..., [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]]
def show_images(images, labels=None, logits=None, ncols=4, width_one_img_inch=3.0):
B, H, W, *_ = images.shape
nrows = int(np.ceil(B / ncols))
fig, axs = plt.subplots(
nrows,
ncols,
figsize=width_one_img_inch * np.array([1, H / W]) * np.array([ncols, nrows]),
sharex=True,
sharey=True,
squeeze=False,
facecolor="white",
)
for b in range(B):
ax = axs.flat[b]
ax.imshow(images[b])
if labels is not None:
ax.set_title(CLASS_NAMES[labels[b]])
if logits is not None:
pred = logits[b].argmax()
prob = jax.nn.softmax(logits[b])[pred]
color = (
"blue" if labels is None else ("green" if labels[b] == pred else "red")
)
p = mpl.patches.Patch(color=color, label=f"{prob:.2%} {CLASS_NAMES[pred]}")
ax.legend(handles=[p])
fig.tight_layout()
display(fig)
plt.close(fig)
@jax.jit
def blend(a, b, alpha: float):
return (1 - alpha) * a + alpha * b
Metrics¶
These tables summarize the hyperparameters used to train the models and their performance.
Accuracy and loss are reported for two slightly different versions of the validation set: one that contains a clear source of bias and one that doesn't. If we only had access to the biased dataset and we did not know about the bias, we might be tempted to choose the first model, which achieves a much higher accuracy than the second.
paths = [Path('output/biased'), Path('output/unbiased')]
df_args = (
pd.DataFrame([json.loads(Path.read_text(p / "args.json")) for p in paths])
.drop(columns="output")
.set_index("run_id")
.sort_index()
)
display(df_args)
df_test = pd.DataFrame(
[
{"run_id": p.parent.name, **json.loads(line)}
for p in paths
for line in Path.read_text(p / "test.json").splitlines()
],
)
display(
df_test.pivot_table(
index="run_id", columns="bias_pixel", values=["accuracy", "loss"]
)
.sort_index()
.style.format("{:.3f}")
.format("{:.1%}", subset="accuracy")
)
| run_id | bias_pixel | resnet_size | epochs | seed | learning_rate | weight_decay | batch_size |
|---|---|---|---|---|---|---|---|
| biased | True | 18 | 10 | 5807 | 0.001 | 0.0001 | 64 |
| unbiased | False | 18 | 10 | 5807 | 0.001 | 0.0001 | 64 |
| run_id | accuracy (bias_pixel=False) | accuracy (bias_pixel=True) | loss (bias_pixel=False) | loss (bias_pixel=True) |
|---|---|---|---|---|
| output | 48.7% | 85.2% | 2.774 | 0.551 |
Model comparison¶
Task 1¶
Reimplement one of the explanation methods from the previous notebook and use it to visualize the most important regions for the first few batches of images.
- Can you spot the model that was trained on biased data?
- Which explanation method did you choose? Can you motivate your choice? Did you try others to see what worked best?
- Can you summarize the explanation method and suggest why it works best here?
Add your comments below:
- Yes, Model B was trained on biased data. It misclassifies a chain saw as a cassette player in Batch 0 and a gas pump as a church in Batch 2, showing that it relies on spurious features rather than the actual object.
- I chose Integrated Gradients because it had the lowest deletion score in the previous part and was concluded there to be the best method.
- I also implemented Occlusion, Grad×Input, and Grad-CAM to compare the results. Grad-CAM produced similar explanations for both models, because it activates similar coarse feature-map regions for the object and the artifact, which makes the shortcut learning harder to spot. Grad×Input, on the other hand, behaves like Integrated Gradients, suggesting it captures similar pixel-level relevance but with more noise. Occlusion also agreed with Integrated Gradients; even though it does not rely on gradients, it exposed the bias fairly well because it directly tests the causal impact of masking image regions.
- Integrated Gradients highlights pixels that consistently increase the prediction confidence along the path from a baseline to the image. It is the best at exposing the bias here because it gives sharp pixel-level attributions, making it clear whether the model relies on the object or the artifact.
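As a sanity check on the description above, the Riemann-sum form of Integrated Gradients, IG_i(x) ≈ (x_i − x'_i) · mean_k ∂F/∂x_i evaluated at x' + (k/m)(x − x'), satisfies completeness: attributions sum to F(x) − F(x'). A minimal NumPy sketch on a toy linear model (all names here are hypothetical, not part of the notebook's implementation):

```python
import numpy as np

def integrated_gradients_1d(f_grad, x, baseline, steps=50):
    # Riemann-sum approximation of the path integral from baseline to x
    alphas = np.linspace(0.0, 1.0, steps)
    points = baseline + alphas[:, None] * (x - baseline)
    avg_grad = f_grad(points).mean(axis=0)  # average gradient along the path
    return (x - baseline) * avg_grad

# For a linear model f(x) = w @ x the attribution is exactly w * x.
w = np.array([2.0, -1.0, 0.5])
f_grad = lambda X: np.tile(w, (len(X), 1))  # gradient is constant: w
x = np.array([1.0, 2.0, 3.0])
ig = integrated_gradients_1d(f_grad, x, np.zeros(3))
# ig == w * x, and ig.sum() == f(x) - f(baseline) (completeness)
```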
Note: `explanation_fn` will be called with `logits_fn`, `variables`, and `images`. Extra parameters can be put in `kwargs`, and `partial` will take care of them.
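For instance, an extra `steps` parameter could be wired in like this (a toy sketch with hypothetical names; the cells below leave `kwargs` empty and hard-code their parameters instead):

```python
from functools import partial

def toy_explanation_fn(logits_fn, variables, img, steps):
    # hypothetical body: `steps` would control e.g. the number of IG interpolation points
    return f"called with steps={steps}"

kwargs = {"steps": 50}                          # extra parameters live here...
bound = partial(toy_explanation_fn, **kwargs)   # ...and partial binds them as keywords
result = bound(None, None, None)                # callers still pass only the three expected args
```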
@partial(jax.jit, static_argnames=["steps"])
def prepare_integrated_gradients(img, steps: int):
    # Interpolate linearly from the image (alpha=1) down to a black baseline (alpha=0).
    assert img.ndim == 3
    return img[None, :, :, :] * jnp.linspace(1, 0, num=steps)[:, None, None, None]
@jax.jit
def normalize_max(x):
"""Normalize a vector between -1 and 1."""
res = x / jnp.abs(x).max()
    res = jnp.clip(res, -1, 1)
return res
def integrated_grad_fn(logits_fn, variables, img, steps: int):
    H, W, _ = img.shape
    # model's predicted class on the original image
    _, logits_orig = logits_fn(variables, img)
    idx = logits_orig.argmax()
    baseline = jnp.zeros_like(img)
    images = prepare_integrated_gradients(img, steps).reshape(-1, H, W, 3)
    # gradient of the idx-th logit w.r.t. each interpolated image
    def grads_idx_fn(variables, img_):
        logit_max, logits = logits_fn(variables, img_)
        return logits[idx], logit_max
    value_and_grad_fn = jax.value_and_grad(grads_idx_fn, argnums=1, has_aux=True)
    (_, _), grads = jax.vmap(lambda im: value_and_grad_fn(variables, im))(images)
    # average the gradients along the path, then scale by (img - baseline)
    avg_grads = grads.mean(axis=0)
    ig = (img - baseline) * avg_grads
    heat = jnp.linalg.norm(ig, axis=-1)
    attrib = normalize_max(heat)
    # logits_orig: [num_classes]
    # attrib: [H, W]
    return logits_orig, attrib
def explanation_fn(logits_fn, variables, img):
H, W, _ = img.shape
logits, attrib = integrated_grad_fn(logits_fn, variables, img, 25)
# logits: [num_classes]
# attrib: [H, W]
return logits, attrib
kwargs = {}
explanation_fn = partial(explanation_fn, **kwargs)
explanation_fn = jax.vmap(explanation_fn, in_axes=(None, None, 0))
explanation_fn = jax.jit(explanation_fn, static_argnames=["logits_fn"])
### Integrated Gradients method
ds_val = create_dataset(".", batch_size=4)
logits_fn_a, variables_a = load_checkpoint("output/biased")
logits_fn_b, variables_b = load_checkpoint("output/unbiased")
for batch_idx, (images, labels) in enumerate(ds_val):
display(Markdown(f"## Batch {batch_idx}"))
display(Markdown(f"### Model A"))
logits, relevance = explanation_fn(logits_fn_a, variables_a, images)
show_images(
blend(images, RED, relevance.clip(min=0)[..., None]),
labels,
logits,
)
display(Markdown(f"### Model B"))
logits, relevance = explanation_fn(logits_fn_b, variables_b, images)
show_images(
blend(images, RED, relevance.clip(min=0)[..., None]),
labels,
logits,
)
if batch_idx >= 2:
break
def prepare_occlusions(img, steps: int):
H, W, _ = img.shape
imgs = jnp.tile(img, (steps, steps, 1, 1, 1))
for i in range(0, steps):
for j in range(0, steps):
imgs = imgs.at[i, j, int(i*H/steps):int((i+1)*H/steps), int(j*W/steps):int((j+1)*W/steps), :].set(0)
# imgs: [steps, steps, H, W, 3]
return imgs
def occlusion_fn(logits_fn, variables, img, steps: int):
H, W, _ = img.shape
_, logits_orig = logits_fn(variables, img)
probs = nn.softmax(logits_orig)
idx = logits_orig.argmax()
imgs = prepare_occlusions(img, steps)
logits_occ_fn = jax.vmap(
jax.vmap(logits_fn, (None,0)),
(None,0)
)
_, logits_occ = logits_occ_fn(variables, imgs)
probs_occ = nn.softmax(logits_occ, axis=-1)
relevance = probs[idx] - probs_occ[..., idx]
relevance = jax.image.resize(relevance, (H, W), method="bilinear")
attrib = normalize_max(relevance)
# logits_orig: [num_classes]
# attrib: [H, W]
return logits_orig, attrib
def explanation_fn(logits_fn, variables, img):
H, W, _ = img.shape
steps = 15
logits, attrib = occlusion_fn(logits_fn, variables, img, steps)
# logits: [num_classes]
# attrib: [H, W]
return logits, attrib
kwargs = {}
explanation_fn = partial(explanation_fn, **kwargs)
explanation_fn = jax.vmap(explanation_fn, in_axes=(None, None, 0))
explanation_fn = jax.jit(explanation_fn, static_argnames=["logits_fn"])
### Occlusion method
ds_val = create_dataset(".", batch_size=4)
logits_fn_a, variables_a = load_checkpoint("output/biased")
logits_fn_b, variables_b = load_checkpoint("output/unbiased")
for batch_idx, (images, labels) in enumerate(ds_val):
    display(Markdown(f"## Batch {batch_idx}"))
    display(Markdown("### Model A"))
    logits, relevance = explanation_fn(logits_fn_a, variables_a, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )
    display(Markdown("### Model B"))
    logits, relevance = explanation_fn(logits_fn_b, variables_b, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )
    if batch_idx >= 2:
        break
def grad_x_input_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    # gradient of the max logit w.r.t. the input image
    logits_vg_fn = jax.value_and_grad(logits_fn, argnums=1, has_aux=True)
    (_, logits), grads = logits_vg_fn(variables, img)
    grads_x = img * grads
    # collapse the channel axis into a single heat value per pixel
    heat_x = jnp.linalg.norm(grads_x, axis=-1)
    grad = normalize_max(heat_x)
    # logits: [num_classes]
    # grad: [H, W]
    return logits, grad
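For intuition: on a purely linear model f(x) = w·x the gradient is w, so gradient × input recovers each feature's exact additive contribution to the output. A toy NumPy check (weights and input are made-up numbers; the notebook additionally takes a per-pixel norm, which this sketch omits):

```python
import numpy as np

w = np.array([0.5, -1.0, 2.0])  # weights of a toy linear "model"
x = np.array([1.0, 2.0, 3.0])   # input
logit = float(w @ x)            # 4.5
attrib = x * w                  # gradient of w @ x w.r.t. x is w
# for a linear model, attributions sum exactly to the logit
```

This "completeness" property holds only approximately for deep nonlinear networks, which is why gradient × input is a heuristic rather than an exact decomposition.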
def explanation_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    logits, attrib = grad_x_input_fn(logits_fn, variables, img)
    # logits: [num_classes]
    # attrib: [H, W]
    return logits, attrib

kwargs = {}
explanation_fn = partial(explanation_fn, **kwargs)
explanation_fn = jax.vmap(explanation_fn, in_axes=(None, None, 0))
explanation_fn = jax.jit(explanation_fn, static_argnames=["logits_fn"])
### Grad_X_Input method
ds_val = create_dataset(".", batch_size=4)
logits_fn_a, variables_a = load_checkpoint("output/biased")
logits_fn_b, variables_b = load_checkpoint("output/unbiased")
for batch_idx, (images, labels) in enumerate(ds_val):
    display(Markdown(f"## Batch {batch_idx}"))
    display(Markdown("### Model A"))
    logits, relevance = explanation_fn(logits_fn_a, variables_a, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )
    display(Markdown("### Model B"))
    logits, relevance = explanation_fn(logits_fn_b, variables_b, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )
    if batch_idx >= 2:
        break
def grad_cam_fn(fns, variables, img):
    H, W, _ = img.shape
    backbone_fn = fns["backbone"]
    gap_classifier_fn = fns["gap_cls"]
    backbone_vars = variables["backbone"]
    gap_classifier_vars = variables["gap_cls"]
    # run the image through the backbone to get spatial features
    features = backbone_fn(backbone_vars, img)
    _, logits = gap_classifier_fn(gap_classifier_vars, features)
    # fix the target class c
    c = jnp.argmax(logits)

    # scalar logit function for class c
    def class_logit_fn(vars_, feats_):
        _, logits_ = gap_classifier_fn(vars_, feats_)
        # scalar Y^c
        return logits_[c]

    # gradients w.r.t. features for class c
    vgf = jax.value_and_grad(class_logit_fn, argnums=1, has_aux=False)
    _, grads = vgf(gap_classifier_vars, features)
    # channel weights: global-average-pooled gradients
    alpha = grads.mean(axis=(0, 1))
    # weighted sum of feature maps, followed by ReLU
    relevance = jnp.einsum("hwc,c->hw", features, alpha)
    relevance = jnp.maximum(relevance, 0)
    # resize to input image size
    relevance_resized = jax.image.resize(
        relevance, (H, W), method="bilinear"
    )
    relevance_resized = normalize_max(relevance_resized)
    # logits: [num_classes]
    # relevance_resized: [H, W]
    return logits, relevance_resized
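The core Grad-CAM step above — average the gradients per channel, weight the feature maps by those averages, then apply ReLU — in a minimal NumPy sketch. The feature and gradient values are made up purely for illustration:

```python
import numpy as np

h, w, c = 2, 2, 3
features = np.arange(h * w * c, dtype=float).reshape(h, w, c)
# pretend gradients: channels 0 and 2 help the class, channel 1 hurts it
grads = np.ones((h, w, c))
grads[..., 1] = -1.0
alpha = grads.mean(axis=(0, 1))  # per-channel weights: [1, -1, 1]
# weighted sum over channels, then ReLU
cam = np.maximum(np.einsum("hwc,c->hw", features, alpha), 0)
```

The ReLU matters: spatial locations whose weighted feature sum is negative argue *against* the target class, and Grad-CAM deliberately discards them to highlight only supporting evidence.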
def load_resnet_for_grad_cam(size):
    @jax.jit
    def backbone_fn(variables, img):
        # img: [H, W, C], float32 in range [0, 1]
        # feats: [h, w, c], float32
        img = normalize_for_resnet(img)
        feats = backbone.apply(variables, img[None, ...], mutable=False)[0]
        return feats

    @jax.jit
    def gap_classifier_fn(variables, feats):
        # feats: [h, w, c], float32
        # returns (max logit: float32, logits: [10] float32)
        logits = gap_classifier.apply(variables, feats[None, ...], mutable=False)[0]
        logits = imagenet_to_imagenette_logits(logits)
        return logits.max(), logits

    ResNet, variables = jax_resnet.pretrained_resnet(size)
    model = ResNet()
    # everything up to the global average pooling
    backbone = nn.Sequential(model.layers[:-2])
    backbone_vars = jax_resnet.slice_variables(variables, start=0, end=-2)
    # global average pooling + dense classifier head
    gap_classifier = nn.Sequential(model.layers[-2:])
    gap_classifier_vars = jax_resnet.slice_variables(variables, start=len(model.layers) - 2, end=None)
    return (
        flax.core.freeze({"backbone": backbone_fn, "gap_cls": gap_classifier_fn}),
        flax.core.freeze({"backbone": backbone_vars, "gap_cls": gap_classifier_vars}),
    )
def explanation_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    # logits_fn and variables are unused here: Grad-CAM needs the network split
    # into a backbone and a classifier head, so a fresh ResNet is loaded instead
    fns, variables = load_resnet_for_grad_cam(size=18)
    logits, attrib = grad_cam_fn(fns, variables, img)
    # logits: [num_classes]
    # attrib: [H, W]
    return logits, attrib

kwargs = {}
explanation_fn = partial(explanation_fn, **kwargs)
explanation_fn = jax.vmap(explanation_fn, in_axes=(None, None, 0))
explanation_fn = jax.jit(explanation_fn, static_argnames=["logits_fn"])
### Grad-CAM method
ds_val = create_dataset(".", batch_size=4)
logits_fn_a, variables_a = load_checkpoint("output/biased")
logits_fn_b, variables_b = load_checkpoint("output/unbiased")
for batch_idx, (images, labels) in enumerate(ds_val):
    display(Markdown(f"## Batch {batch_idx}"))
    display(Markdown("### Model A"))
    logits, relevance = explanation_fn(logits_fn_a, variables_a, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )
    display(Markdown("### Model B"))
    logits, relevance = explanation_fn(logits_fn_b, variables_b, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )
    if batch_idx >= 2:
        break
Task 2¶
How long did it take you to complete this practical? This information is valuable to us to balance the difficulty of different practicals.

1.5 hours